UNPKG

@lobehub/chat

Version:

Lobe Chat - an open-source, high-performance chatbot framework that supports speech synthesis, multimodal input, and an extensible Function Call plugin system. It supports one-click free deployment of your private ChatGPT/LLM web application.

230 lines (188 loc) 7.39 kB
// @vitest-environment node
import { getAuth } from '@clerk/nextjs/server';
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';

import { checkAuthMethod } from '@/app/(backend)/middleware/auth/utils';
import { LOBE_CHAT_AUTH_HEADER, OAUTH_AUTHORIZED } from '@/const/auth';
import { AgentRuntime, LobeRuntimeAI } from '@/libs/model-runtime';
import { ChatErrorType } from '@/types/fetch';
import { getJWTPayload } from '@/utils/server/jwt';

import { POST } from './route';

vi.mock('@clerk/nextjs/server', () => ({
  getAuth: vi.fn(),
}));

vi.mock('@/app/(backend)/middleware/auth/utils', () => ({
  checkAuthMethod: vi.fn(),
}));

vi.mock('@/utils/server/jwt', () => ({
  getJWTPayload: vi.fn(),
}));

// Backing value for the mocked `enableClerk` flag; tests flip it and afterEach resets it.
let enableClerk = false;

// Partially mock @/const/auth: keep every real export, but expose `enableClerk` through a
// getter so each read observes the current value of the variable above.
vi.mock('@/const/auth', async (importOriginal) => {
  const modules = await importOriginal<typeof import('@/const/auth')>();
  return {
    ...modules,
    get enableClerk() {
      return enableClerk;
    },
  };
});

/**
 * Stub `AgentRuntime.initializeWithProvider` so POST never builds a real provider runtime;
 * it resolves to an AgentRuntime wrapping a minimal fake runtime.
 * Returns the spy so callers can assert how it was invoked.
 */
const stubAgentRuntimeInit = () => {
  const mockRuntime: LobeRuntimeAI = { baseURL: 'abc', chat: vi.fn() };
  return vi
    .spyOn(AgentRuntime, 'initializeWithProvider')
    .mockResolvedValue(new AgentRuntime(mockRuntime));
};

// Shared request carrying a bearer token and the OAuth-authorized header.
// Rebuilt before every test; some tests reassign it with a different body.
let request: Request;

beforeEach(() => {
  request = new Request(new URL('https://test.com'), {
    headers: {
      [LOBE_CHAT_AUTH_HEADER]: 'Bearer some-valid-token',
      [OAUTH_AUTHORIZED]: 'true',
    },
    method: 'POST',
    body: JSON.stringify({ model: 'test-model' }),
  });
});

afterEach(() => {
  // Clear recorded calls on every mock.
  vi.clearAllMocks();
  // Undo vi.spyOn replacements (AgentRuntime.initializeWithProvider,
  // AgentRuntime.prototype.chat) so stubs cannot leak into later tests.
  vi.restoreAllMocks();
  enableClerk = false;
});

describe('POST handler', () => {
  describe('init chat model', () => {
    it('should initialize AgentRuntime correctly with valid authorization', async () => {
      const mockParams = Promise.resolve({ provider: 'test-provider' });

      // Mock the decoded JWT payload returned for the auth header.
      vi.mocked(getJWTPayload).mockResolvedValueOnce({
        accessCode: 'test-access-code',
        apiKey: 'test-api-key',
        azureApiVersion: 'v1',
      });

      // migrate to new AgentRuntime init api
      const spy = stubAgentRuntimeInit();

      await POST(request, { params: mockParams });

      // The auth header is decoded and the runtime is initialized for the route's provider.
      expect(getJWTPayload).toHaveBeenCalledWith('Bearer some-valid-token');
      expect(spy).toHaveBeenCalledWith('test-provider', expect.anything());
    });

    it('should return Unauthorized error when LOBE_CHAT_AUTH_HEADER is missing', async () => {
      const mockParams = Promise.resolve({ provider: 'test-provider' });
      const requestWithoutAuthHeader = new Request(new URL('https://test.com'), {
        method: 'POST',
        body: JSON.stringify({ model: 'test-model' }),
      });

      const response = await POST(requestWithoutAuthHeader, { params: mockParams });

      expect(response.status).toBe(401);
      expect(await response.json()).toEqual({
        body: {
          error: { errorType: 401 },
          provider: 'test-provider',
        },
        errorType: 401,
      });
    });

    it('should have pass clerk Auth when enable clerk', async () => {
      enableClerk = true;

      vi.mocked(getJWTPayload).mockResolvedValueOnce({
        accessCode: 'test-access-code',
        apiKey: 'test-api-key',
        azureApiVersion: 'v1',
      });

      const mockParams = Promise.resolve({ provider: 'test-provider' });

      // Clerk's getAuth yields an empty auth object, expected to be forwarded as `clerkAuth`.
      vi.mocked(getAuth).mockReturnValue({} as any);
      vi.mocked(checkAuthMethod).mockReset();

      stubAgentRuntimeInit();

      const requestWithClerk = new Request(new URL('https://test.com'), {
        method: 'POST',
        body: JSON.stringify({ model: 'test-model' }),
        headers: {
          [LOBE_CHAT_AUTH_HEADER]: 'some-valid-token',
          [OAUTH_AUTHORIZED]: '1',
        },
      });

      await POST(requestWithClerk, { params: mockParams });

      expect(checkAuthMethod).toHaveBeenCalledWith({
        accessCode: 'test-access-code',
        apiKey: 'test-api-key',
        clerkAuth: {},
        nextAuthAuthorized: true,
      });
    });

    it('should return InternalServerError error when throw a unknown error', async () => {
      const mockParams = Promise.resolve({ provider: 'test-provider' });
      vi.mocked(getJWTPayload).mockRejectedValueOnce(new Error('unknown error'));

      const response = await POST(request, { params: mockParams });

      expect(response.status).toBe(500);
      expect(await response.json()).toEqual({
        body: {
          error: {},
          provider: 'test-provider',
        },
        errorType: 500,
      });
    });
  });

  describe('chat', () => {
    it('should correctly handle chat completion with valid payload', async () => {
      vi.mocked(getJWTPayload).mockResolvedValueOnce({
        accessCode: 'test-access-code',
        apiKey: 'test-api-key',
        azureApiVersion: 'v1',
        userId: 'abc',
      });

      const mockParams = Promise.resolve({ provider: 'test-provider' });
      const mockChatPayload = { message: 'Hello, world!' };
      request = new Request(new URL('https://test.com'), {
        headers: { [LOBE_CHAT_AUTH_HEADER]: 'Bearer some-valid-token' },
        method: 'POST',
        body: JSON.stringify(mockChatPayload),
      });

      // Stub runtime creation explicitly instead of relying on a stub left by an earlier test.
      stubAgentRuntimeInit();

      const mockChatResponse: any = { success: true, message: 'Reply from agent' };
      vi.spyOn(AgentRuntime.prototype, 'chat').mockResolvedValue(mockChatResponse);

      const response = await POST(request, { params: mockParams });

      // The route returns the runtime's response as-is and forwards the JWT userId as `user`.
      expect(response).toEqual(mockChatResponse);
      expect(AgentRuntime.prototype.chat).toHaveBeenCalledWith(mockChatPayload, {
        user: 'abc',
        signal: expect.anything(),
      });
    });

    it('should return an error response when chat completion fails', async () => {
      // Mock the decoded JWT payload; runtime creation is stubbed below.
      vi.mocked(getJWTPayload).mockResolvedValueOnce({
        accessCode: 'test-access-code',
        apiKey: 'test-api-key',
        azureApiVersion: 'v1',
      });

      const mockParams = Promise.resolve({ provider: 'test-provider' });
      const mockChatPayload = { message: 'Hello, world!' };
      request = new Request(new URL('https://test.com'), {
        headers: { [LOBE_CHAT_AUTH_HEADER]: 'Bearer some-valid-token' },
        method: 'POST',
        body: JSON.stringify(mockChatPayload),
      });

      // Stub runtime creation explicitly instead of relying on a stub left by an earlier test.
      stubAgentRuntimeInit();

      const mockErrorResponse = {
        errorType: ChatErrorType.InternalServerError,
        errorMessage: 'Something went wrong',
      };

      vi.spyOn(AgentRuntime.prototype, 'chat').mockRejectedValue(mockErrorResponse);

      const response = await POST(request, { params: mockParams });

      expect(response.status).toBe(500);
      expect(await response.json()).toEqual({
        body: {
          errorMessage: 'Something went wrong',
          error: {
            errorMessage: 'Something went wrong',
            errorType: 500,
          },
          provider: 'test-provider',
        },
        errorType: 500,
      });
    });
  });
});